{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "pip install lightly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: The model and training settings do not follow the reference settings\n",
    "# from the paper. The settings are chosen such that the example can easily be\n",
    "# run on a small dataset with a single GPU.\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.loss import PMSNLoss\n",
    "from lightly.models import utils\n",
    "from lightly.models.modules import MaskedVisionTransformerTorchvision\n",
    "from lightly.models.modules.heads import MSNProjectionHead\n",
    "from lightly.transforms import MSNTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PMSN(nn.Module):\n",
    "    def __init__(self, vit):\n",
    "        super().__init__()\n",
    "\n",
    "        self.mask_ratio = 0.15\n",
    "        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n",
    "        self.projection_head = MSNProjectionHead(384)\n",
    "\n",
    "        self.anchor_backbone = copy.deepcopy(self.backbone)\n",
    "        self.anchor_projection_head = copy.deepcopy(self.projection_head)\n",
    "\n",
    "        utils.deactivate_requires_grad(self.backbone)\n",
    "        utils.deactivate_requires_grad(self.projection_head)\n",
    "\n",
    "        self.prototypes = nn.Linear(256, 1024, bias=False).weight\n",
    "\n",
    "    def forward(self, images):\n",
    "        out = self.backbone(images=images)\n",
    "        return self.projection_head(out)\n",
    "\n",
    "    def forward_masked(self, images):\n",
    "        batch_size, _, _, width = images.shape\n",
    "        seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2\n",
    "        idx_keep, _ = utils.random_token_mask(\n",
    "            size=(batch_size, seq_length),\n",
    "            mask_ratio=self.mask_ratio,\n",
    "            device=images.device,\n",
    "        )\n",
    "        out = self.anchor_backbone(images=images, idx_keep=idx_keep)\n",
    "        return self.anchor_projection_head(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT small configuration (ViT-S/16)\n",
    "vit = torchvision.models.VisionTransformer(\n",
    "    image_size=224,\n",
    "    patch_size=16,\n",
    "    num_layers=12,\n",
    "    num_heads=6,\n",
    "    hidden_dim=384,\n",
    "    mlp_dim=384 * 4,\n",
    ")\n",
    "model = PMSN(vit)\n",
    "# # or use a torchvision ViT backbone:\n",
    "# vit = torchvision.models.vit_b_32(pretrained=False)\n",
    "# model = PMSN(vit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "transform = MSNTransform()\n",
    "# we ignore object detection annotations by setting target_transform to return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_transform(t):\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = torchvision.datasets.VOCDetection(\n",
    "    \"datasets/pascal_voc\",\n",
    "    download=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos:\n",
    "# dataset = LightlyDataset(\"path/to/folder\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = PMSNLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = [\n",
    "    *list(model.anchor_backbone.parameters()),\n",
    "    *list(model.anchor_projection_head.parameters()),\n",
    "    model.prototypes,\n",
    "]\n",
    "optimizer = torch.optim.AdamW(params, lr=1.5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Starting Training\")\n",
    "for epoch in range(10):\n",
    "    total_loss = 0\n",
    "    for batch in dataloader:\n",
    "        views = batch[0]\n",
    "        utils.update_momentum(model.anchor_backbone, model.backbone, 0.996)\n",
    "        utils.update_momentum(\n",
    "            model.anchor_projection_head, model.projection_head, 0.996\n",
    "        )\n",
    "\n",
    "        views = [view.to(device, non_blocking=True) for view in views]\n",
    "        targets = views[0]\n",
    "        anchors = views[1]\n",
    "        anchors_focal = torch.concat(views[2:], dim=0)\n",
    "\n",
    "        targets_out = model.backbone(images=targets)\n",
    "        targets_out = model.projection_head(targets_out)\n",
    "        anchors_out = model.forward_masked(anchors)\n",
    "        anchors_focal_out = model.forward_masked(anchors_focal)\n",
    "        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)\n",
    "\n",
    "        loss = criterion(anchors_out, targets_out, model.prototypes.data)\n",
    "        total_loss += loss.detach()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "    avg_loss = total_loss / len(dataloader)\n",
    "    print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
